{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f6cc43a4",
   "metadata": {},
   "source": [
    "# Out-of-DIstribution detector for Neural networks (ODIN) with mobilenet\n",
    "\n",
    "The goal of this notebook is to understand [Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks](https://arxiv.org/abs/1706.02690). We transform a trained classifier into a out-of-distribution detector. Here, we will take a small mobilenet neural network trained on CIFAR10 and see how to use it to make a detector for out-of-distirbution samples obtianed by cropping images from the ImageNet dataset.\n",
    "\n",
    "The image below is taken from the original paper and shows the performances obtained with the ODIN method on DenseNet-BC-100 network.\n",
    "![](https://raw.githubusercontent.com/ShiyuLiang/odin-pytorch/master/figures/original_optimal_shade.png)\n",
    "\n",
    "A detector is a binary classifier thath needs to output $1$ when the input image is from the in-distribution (CIFAR10) and $0$ when the input image is from the out-distribution (ImageNet-crop).\n",
    "TPR on CIFAR10 is then the probability that the detector is correct on an image from CIFAR10 and FPR on TinyImagenNet (crop) is the probability that the detector is not correct on an image from ImageNet-crop.\n",
    "As in binary classification, we aim at maximizing TPR while minimizing FPR and the tradeoff between these two goals is capured by a ROC curve (shown above).\n",
    "\n",
    "In this practical, we derive ROC curves for the ODIN methods with the mobilenet network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54e18c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import sys\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from sklearn import metrics\n",
    "\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45a1537d",
   "metadata": {},
   "source": [
    "# Mobilenet for CIFAR10\n",
    "\n",
    "We will use a pretrained model called [Mobilenet v2](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/) which is an efficient network optimized for speed and memory, with [residual blocks](https://dataflowr.github.io/website/modules/17-resnets/). The default version from `torch.vision` is pretrained on Imagenet so we will rely on the version provided by [PyTorch_CIFAR10](https://github.com/huyvnphan/PyTorch_CIFAR10) which was trained on CIFAR10.\n",
    "\n",
    "The code below allows you to download the weights of the neural network and the [piece of code](https://github.com/dataflowr/notebooks/blob/master/Module17/mobilenetv2.py) that will be useful for us (without cloning the whole repo above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb5d4d3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the code below downloads the weights of the neural net and should be run only once.\n",
    "# You can uncomment it and run it, the first time you are using it or if you are running on colab.\n",
    "#%mkdir state_dicts\n",
    "#%cd state_dicts/\n",
    "#!wget https://www.di.ens.fr/~lelarge/mobilenet_v2.pt\n",
    "#%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c81b932e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# if on colab first uncomment and run the following command (only once):\n",
    "#!wget https://raw.githubusercontent.com/dataflowr/notebooks/master/Module17/mobilenetv2.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1be80f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mobilenetv2 import mobilenet_v2\n",
    "\n",
    "model = mobilenet_v2(pretrained=True)\n",
    "model.eval()\n",
    "model.to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1e678a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "NORM_BIAS = [125.3 / 255, 123.0 / 255, 113.9 / 255]\n",
    "NORM_SCALE = [63.0 / 255, 62.1 / 255.0, 66.7 / 255.0]\n",
    "EPS_FSGM = 1e-2\n",
    "IMAGE_SIZE = (32, 32)\n",
    "batch_size = 64\n",
    "\n",
    "%mkdir data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19c295c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(NORM_BIAS, NORM_SCALE),\n",
    "        transforms.Resize(IMAGE_SIZE),\n",
    "    ]\n",
    ")\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root=\"./data/\", train=False, download=True, transform=transform)\n",
    "testloaderIn = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4489735",
   "metadata": {},
   "source": [
    "## Question 1: test your network\n",
    "\n",
    "Use the code from the [course](https://dataflowr.github.io/website/modules/5-stacking-layers/) (when we did overfit a MLP on CIFAR10) in order to get the performance of the pretrained network on CIFAR10 test set.\n",
    "\n",
    "The dataloader is given to you above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d744109",
   "metadata": {},
   "outputs": [],
   "source": [
    "# your code"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07b2e3a3",
   "metadata": {},
   "source": [
    "## Out-of-distribution Dataset\n",
    "\n",
    "The code below allows you to download the out-of-distribution dataset.\n",
    "You should not modify this code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fc96973",
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(inp, title=None):\n",
    "#   Imshow for Tensor.\n",
    "    inp = inp.numpy().transpose((1, 2, 0))\n",
    "    inp = np.clip(NORM_SCALE * inp + NORM_BIAS, 0,1)\n",
    "    plt.imshow(inp)\n",
    "    if title is not None:\n",
    "        plt.title(title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12fcbd26",
   "metadata": {},
   "outputs": [],
   "source": [
    "batchIn = next(iter(testloaderIn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb04bdef",
   "metadata": {},
   "outputs": [],
   "source": [
    "imgIn = batchIn[0]\n",
    "imgIn = torch.narrow(imgIn, 0, 0, 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c20e999",
   "metadata": {},
   "outputs": [],
   "source": [
    "im_in = torchvision.utils.make_grid(imgIn)\n",
    "\n",
    "imshow(im_in,title='images from CIFAR10')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1012280e",
   "metadata": {},
   "outputs": [],
   "source": [
    "CIFAR10_labels = ['airplane', 'automobile','bird','cat','deer','dog','frog','horse','ship','truck']\n",
    "preds = model(imgIn.to(DEVICE)).max(1, keepdim=True)[1]\n",
    "[CIFAR10_labels[pred] for pred in preds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa323ce6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The code below dowload the out-of-distribution dataset\n",
    "# if on colab first uncomment and run the following command (only once):\n",
    "#%cd data\n",
    "#!wget https://www.dropbox.com/s/avgm2u562itwpkl/Imagenet.tar.gz\n",
    "#!tar -xvzf Imagenet.tar.gz\n",
    "#%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4cce31c",
   "metadata": {},
   "outputs": [],
   "source": [
    "testsetout = torchvision.datasets.ImageFolder(\"./data/Imagenet/\", transform=transform)\n",
    "testloaderOut = DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efd2e8f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "batchOut = next(iter(testloaderOut))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eed35b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "imgOut = batchOut[0]\n",
    "imgOut = torch.narrow(imgOut, 0, 0, 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5533600d",
   "metadata": {},
   "outputs": [],
   "source": [
    "im_out = torchvision.utils.make_grid(imgOut)\n",
    "\n",
    "imshow(im_out, title='images from ImageNet-crop')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1beec9de",
   "metadata": {},
   "source": [
    "We have now two dataloaders `testloaderIn` and `testloaderOut` corresponding to the in- and out-distributio, and ou mobilenet network `model`. \n",
    "\n",
    "# Out-of-distribution Detector\n",
    "\n",
    "Our network `model` has been trained on CIFAR10 and hence outputs a vetor of size 10 representing the logprobs of being one of the 10 classes of CIFAR10. To use our model as a detector, we need to transform this output into a single value corresponding to the probability of being from the in-distribution. Hence, we ant thsi output to be high for images from CIFAR10 and low for images form ImageNet-crop.\n",
    "\n",
    "## Temperature scaling\n",
    "\n",
    "The first method proposed in [ODIN](https://arxiv.org/abs/1706.02690) relies on the fact that the trained network should be more confident on images from the in-distribution than on images form the out-distribution.\n",
    "\n",
    "For an image ${\\bf x}$, let $f({\\bf x}) = (f_1({\\bf x}),\\dots, f_C({\\bf x}))\\in \\mathbb{R}^C$ be the output of the network, where $C$ is the number of classes. We define for a temperature scaling parameter $T>0$,\n",
    "\\begin{eqnarray*}\n",
    "S_i({\\bf x},T) = \\frac{\\exp(f_i({\\bf x})/T)}{\\sum_{j=1}^C \\exp(f_j({\\bf x})/T)}\\geq 0,\n",
    "\\end{eqnarray*}\n",
    "so that $\\sum_i S_i({\\bf x},T)=1$. The **softmax score** is then $\\max_i S_i({\\bf x},T)$.\n",
    "\n",
    "We expect the softmax score to be higher for images from the in-distribution than for images from the out-distribution. Hence we can build a classifier based on this score."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5440cca4",
   "metadata": {},
   "source": [
    "## Question 2: compute the softmax score\n",
    "\n",
    "Compute the softmax scores with $T=1$ and $T=1000$ and plot the corresponding ROC curves (using [`metrics.roc_curve`](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html)).\n",
    "\n",
    "Hint: to get a picture similar to the one in the paper shown above use `plt.ylim(0.8, 1.002)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee031f05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# this function takes a model, an algorithm and a dataloader and compute all the corresponding scores\n",
    "# To use this function your algorithm needs to have a method apply taking as input \n",
    "# a batch and the model in order to compute the scores for the batch\n",
    "def compute_scores(model, algo, loader, device=DEVICE):\n",
    "    model.eval()\n",
    "    all_scores = []\n",
    "    for i, (batch, targets) in enumerate(loader):\n",
    "        bs = batch.shape[0]\n",
    "        batch = batch.to(device)\n",
    "        scores = algo.apply(batch, model)\n",
    "        all_scores += scores\n",
    "    return all_scores\n",
    "# You can test your method apply by running commands like:\n",
    "# name_of_your_algo.apply(imgIn.to(DEVICE), model)\n",
    "# name_of_your_algo.apply(imgOut.to(DEVICE), model)\n",
    "# to obtain the scores on the small batchs made above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1014e3f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# your code"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70581585",
   "metadata": {},
   "source": [
    "## Input preprocessing\n",
    "\n",
    "The second method relies on perturbing the imput image. The main idea is to perturb the image in order to increase its softmax score as follows:\n",
    "\\begin{eqnarray*}\n",
    "\\tilde{\\bf x} = {\\bf x} - \\epsilon \\text{sign}\\left( -\\nabla_{\\bf x} \\log S_{\\hat{y}}({\\bf x},T)\\right),\n",
    "\\end{eqnarray*}\n",
    "where $\\hat{y} =\\arg\\max_i S_i({\\bf x},T)$ and $\\epsilon>0$ is the noise magnitude applied on the image.\n",
    "This perturbation should have a much higher impact on CIFAR10 images (in-distribution) than on ImageNet images (out-distribution) making the detection easier.\n",
    "\n",
    "This second method can be used with the temperature scaling (this is why we keep the parameter $T$ in the formula above). Note that this method does not require the label of the image, as a result it can be applied to in-domain images as well as out-of domain images. The gradient can be computed with a standard [cross-entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) but note that the gradient is taken with respect to the pixels of the image not the weights of the neural network which are untouched. A similar technique has been used in [adversarial attacks](https://dataflowr.github.io/website/homework/2-CAM-adversarial/) where the goal was to perturb the image in order to fool the prediction made by the neural network."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2def5b6f",
   "metadata": {},
   "source": [
    "## Question 3: code the perturbation on the image and the ODIN algorithm with both perturbation and temperature scaling.\n",
    "\n",
    "Hints: \n",
    "- make a clone of the images and make it differentiable: \n",
    "```\n",
    "inputs = images.clone()\n",
    "inputs.requires_grad=True\n",
    "```\n",
    "- compute the label with `argmax`, the loss and do the backpropagation.\n",
    "- compute the sign of the gradient with [`torch.ge`](https://pytorch.org/docs/stable/generated/torch.ge.html#torch-ge) for example to get your binary gradient\n",
    "- remember that the images are first scaled with `NORM_SCALE`, hence divide your binary gradient by these quantity in order to implement exactly the algorithm described above (and to get better results!).\n",
    "- for the ODIN algorithm creat a class as you did above with a `apply` method so that you can reuse your code for computing scores. You can test your algo with commands like\n",
    "```\n",
    "scores_In_odin = odin.apply(imgIn.to(DEVICE), model)\n",
    "scores_Out_odin = odin.apply(imgOut.to(DEVICE), model)\n",
    "```\n",
    "\n",
    "Plot the resulting ROC curves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c044bcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# your code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08a6b56d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
